{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e7256d02",
   "metadata": {},
   "source": [
    "# Interactive posterior predictives checks\n",
    "\n",
    "This notebook demonstrates how to interactively examine model priors using [ipywidgets](https://ipywidgets.readthedocs.io/en/stable/).\n",
    "\n",
    "⚠️ This notebook is intended to be run interactively. Please run locally or [Open in Colab](https://colab.research.google.com/github/pyro-ppl/pyro/blob/dev/tutorial/source/prior_predictive.ipynb).\n",
    "\n",
    "The first step in [Bayesian workflow](https://arxiv.org/abs/2011.01808) is to create a model. The second step is to check prior samples from the model. This notebook shows how to interactively check prior samples and tune parameters of the top level prior distribution while visualizing model outputs.\n",
    "\n",
    "#### Summary\n",
    "\n",
    "- Wrap your model in a plotting function.\n",
    "- Use [ipywidgets.interact()](https://ipywidgets.readthedocs.io/en/stable/examples/Using%20Interact.html) to create sliders for each parameter of your prior.\n",
    "- For expensive models, use a [Resampler](https://docs.pyro.ai/en/stable/infer.util.html#pyro.infer.resampler.Resampler)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0b4f8ff1",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q pyro-ppl  # for colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5184344d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from ipywidgets import interact, FloatSlider\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "import pyro.poutine as poutine\n",
    "from pyro.infer.resampler import Resampler\n",
    "\n",
    "assert pyro.__version__.startswith('1.9.1')\n",
    "smoke_test = ('CI' in os.environ)  # for CI testing only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e10c32bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(T: int = 1000, data=None):\n",
    "    # Sample parameters from the prior.\n",
    "    df = pyro.sample(\"df\", dist.LogNormal(0, 1))\n",
    "    p_scale = pyro.sample(\"p_scale\", dist.LogNormal(0, 1))  # process noise\n",
    "    m_scale = pyro.sample(\"m_scale\", dist.LogNormal(0, 1))  # measurement noise\n",
    "    \n",
    "    # Simulate a time series.\n",
    "    with pyro.plate(\"dt\", T):\n",
    "        process_noise = pyro.sample(\"process_noise\", dist.StudentT(df, 0, p_scale))\n",
    "    trend = pyro.deterministic(\"trend\", process_noise.cumsum(-1))\n",
    "    with pyro.plate(\"t\", T):\n",
    "        return pyro.sample(\"obs\", dist.Normal(trend, m_scale), obs=data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "03d29ed7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_trajectory(df=1.0, p_scale=1.0, m_scale=1.0):\n",
    "    pyro.set_rng_seed(12345)\n",
    "    data = {\n",
    "        \"df\": torch.as_tensor(df),\n",
    "        \"p_scale\": torch.as_tensor(p_scale),\n",
    "        \"m_scale\": torch.as_tensor(m_scale),\n",
    "    }\n",
    "    trajectory = poutine.condition(model, data)()\n",
    "    plt.figure(figsize=(8, 4)).patch.set_color(\"white\")\n",
    "    plt.plot(trajectory)\n",
    "    plt.xlabel(\"time\")\n",
    "    plt.ylabel(\"obs\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b0348ce",
   "metadata": {},
   "source": [
    "Now we can examine what model trajectories look like for particular values of top level latent variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "18c11eb3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ce0e5489600743b4b0c8c371ee95186c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatSlider(value=1.0, description='df', max=10.0, min=0.01), FloatSlider(value=0.1, des…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "interact(\n",
    "    plot_trajectory,\n",
    "    df=FloatSlider(value=1.0, min=0.01, max=10.0),\n",
    "    p_scale=FloatSlider(value=0.1, min=0.01, max=1.0),\n",
    "    m_scale=FloatSlider(value=1.0, min=0.01, max=10.0),\n",
    ");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b71dd6e",
   "metadata": {},
   "source": [
    "But to tune the parameters of our priors, we'd like to look at an ensemble of trajectories each of whose top-level parameters is sampled from the current prior.  Let's rewrite our model so we can input the prior parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "776d52f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model2(T: int = 1000, data=None, df0=0, df1=1, p0=0, p1=1, m0=0, m1=1):\n",
    "    # Sample parameters from the prior.\n",
    "    df = pyro.sample(\"df\", dist.LogNormal(df0, df1))\n",
    "    p_scale = pyro.sample(\"p_scale\", dist.LogNormal(p0, p1))  # process noise\n",
    "    m_scale = pyro.sample(\"m_scale\", dist.LogNormal(m0, m1))  # measurement noise\n",
    "    \n",
    "    # Simulate a time series.\n",
    "    with pyro.plate(\"dt\", T):\n",
    "        process_noise = pyro.sample(\"process_noise\", dist.StudentT(df, 0, p_scale))\n",
    "    trend = pyro.deterministic(\"trend\", process_noise.cumsum(-1))\n",
    "    with pyro.plate(\"t\", T):\n",
    "        return pyro.sample(\"obs\", dist.Normal(trend, m_scale), obs=data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2daa8449",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_trajectories(**kwargs):\n",
    "    pyro.set_rng_seed(12345)\n",
    "    with pyro.plate(\"trajectories\", 20, dim=-2):\n",
    "        trajectories = model2(**kwargs)\n",
    "    plt.figure(figsize=(8, 5)).patch.set_color(\"white\")\n",
    "    plt.plot(trajectories.T)\n",
    "    plt.xlabel(\"time\")\n",
    "    plt.ylabel(\"obs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "78a75ad1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "36f97b3bc75941ebb9a00556b948e0d4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatSlider(value=0.0, description='df0', max=5.0, min=-5.0), FloatSlider(value=1.0, des…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "interact(\n",
    "    plot_trajectories,\n",
    "    df0=FloatSlider(value=0.0, min=-5, max=5),\n",
    "    df1=FloatSlider(value=1.0, min=0.1, max=10),\n",
    "    p0=FloatSlider(value=0.0, min=-5, max=5),\n",
    "    p1=FloatSlider(value=1.0, min=0.1, max=10),\n",
    "    m0=FloatSlider(value=0.0, min=-5, max=5),\n",
    "    m1=FloatSlider(value=1.0, min=0.1, max=10),\n",
    ");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e663568",
   "metadata": {},
   "source": [
    "Yikes! It looks like our initial priors generated very weird trajectories, but we can slide to find better priors.  Try increasing `df0`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e9bacfa",
   "metadata": {},
   "source": [
    "## Resampler\n",
    "\n",
    "For more expensive simulations, sampling may be too slow to interactively generate samples at each change. As a computational trick we can draw many samples once from a diffuse distribution, then resample them from a modified distribution -- provided we importance sample or resample. Pyro provides an importance [Resampler](https://docs.pyro.ai/en/stable/infer.util.html#pyro.infer.resampler.Resampler) to aid in interactively visualizing expensive models.\n",
    "\n",
    "We'll start with our original model and create a way to make parametrized partial models with given priors. These partial models are just the top half our our model, the top level parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "41b808de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_partial_model(df0, df1, p0, p1, m0, m1):\n",
    "    def partial_model():\n",
    "        # Sample parameters from the prior.\n",
    "        pyro.sample(\"df\", dist.LogNormal(df0, df1))\n",
    "        pyro.sample(\"p_scale\", dist.LogNormal(p0, p1))  # process noise\n",
    "        pyro.sample(\"m_scale\", dist.LogNormal(m0, m1))  # measurement noise\n",
    "    return partial_model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7acbde94",
   "metadata": {},
   "source": [
    "Next we'll initialize the `Resampler` with a diffuse guide that covers most of our desired parameter space. This can be expensive in real simulations, so you might want to run it overnight."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "74a685c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 940 ms, sys: 146 ms, total: 1.09 s\n",
      "Wall time: 934 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "partial_guide = make_partial_model(0, 10, 0, 10, 0, 10)\n",
    "resampler = Resampler(partial_guide, model, num_guide_samples=10000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "785c131c",
   "metadata": {},
   "source": [
    "The `Resampler.sample()` method takes a modified partial model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c939c991",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_resampled(df0, df1, p0, p1, m0, m1):\n",
    "    partial_model = make_partial_model(df0, df1, p0, p1, m0, m1)\n",
    "    samples = resampler.sample(partial_model, num_samples=20)\n",
    "    trajectories = samples[\"obs\"]\n",
    "    plt.figure(figsize=(8, 5)).patch.set_color(\"white\")\n",
    "    plt.plot(trajectories.T)\n",
    "    plt.xlabel(\"time\")\n",
    "    plt.ylabel(\"obs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0fe987b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3219e20ecdff4b6892b50894fc76be1b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatSlider(value=0.0, description='df0', max=5.0, min=-5.0), FloatSlider(value=1.0, des…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "interact(\n",
    "    plot_resampled,\n",
    "    df0=FloatSlider(value=0.0, min=-5, max=5),\n",
    "    df1=FloatSlider(value=1.0, min=0.1, max=10),\n",
    "    p0=FloatSlider(value=0.0, min=-5, max=5),\n",
    "    p1=FloatSlider(value=1.0, min=0.1, max=10),\n",
    "    m0=FloatSlider(value=0.0, min=-5, max=5),\n",
    "    m1=FloatSlider(value=1.0, min=0.1, max=10),\n",
    ");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7107e15",
   "metadata": {},
   "source": [
    "After deciding on good prior parameters, we can then hard-code those into the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "112e9537",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(T: int = 1000, data=None):\n",
    "    df = pyro.sample(\"df\", dist.LogNormal(4, 1))  # <-- changed 0 to 4\n",
    "    p_scale = pyro.sample(\"p_scale\", dist.LogNormal(1, 1))  # <-- changed 0 to 1\n",
    "    m_scale = pyro.sample(\"m_scale\", dist.LogNormal(0, 1))\n",
    "\n",
    "    with pyro.plate(\"dt\", T):\n",
    "        process_noise = pyro.sample(\"process_noise\", dist.StudentT(df, 0, p_scale))\n",
    "    trend = pyro.deterministic(\"trend\", process_noise.cumsum(-1))\n",
    "    with pyro.plate(\"t\", T):\n",
    "        return pyro.sample(\"obs\", dist.Normal(trend, m_scale), obs=data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
